# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys
import time
import torch
import json
import itertools
import accelerate
import torch.distributed as dist
import torch.nn.functional as F
from tqdm import tqdm
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter

from torch.optim import AdamW
from torch.optim.lr_scheduler import ExponentialLR

from librosa.filters import mel as librosa_mel_fn

from accelerate.logging import get_logger
from pathlib import Path

from utils.io import save_audio
from utils.data_utils import *
from utils.util import (
    Logger,
    ValueWindow,
    remove_older_ckpt,
    set_all_random_seed,
    save_config,
)
from utils.mel import extract_mel_features
from models.vocoders.vocoder_trainer import VocoderTrainer
from models.vocoders.gan.gan_vocoder_dataset import (
    GANVocoderDataset,
    GANVocoderCollator,
)

from models.vocoders.gan.generator.bigvgan import BigVGAN
from models.vocoders.gan.generator.hifigan import HiFiGAN
from models.vocoders.gan.generator.melgan import MelGAN
from models.vocoders.gan.generator.nsfhifigan import NSFHiFiGAN
from models.vocoders.gan.generator.apnet import APNet

from models.vocoders.gan.discriminator.mpd import MultiPeriodDiscriminator
from models.vocoders.gan.discriminator.mrd import MultiResolutionDiscriminator
from models.vocoders.gan.discriminator.mssbcqtd import MultiScaleSubbandCQTDiscriminator
from models.vocoders.gan.discriminator.msd import MultiScaleDiscriminator
from models.vocoders.gan.discriminator.msstftd import MultiScaleSTFTDiscriminator

from models.vocoders.gan.gan_vocoder_inference import vocoder_inference

# Registry used by ``_build_model``: maps the ``cfg.model.generator`` name to
# the generator class to instantiate.
supported_generators = dict(
    bigvgan=BigVGAN,
    hifigan=HiFiGAN,
    melgan=MelGAN,
    nsfhifigan=NSFHiFiGAN,
    apnet=APNet,
)

# Registry used by ``_build_model``: maps each name listed in
# ``cfg.model.discriminators`` to the discriminator class to instantiate.
supported_discriminators = dict(
    mpd=MultiPeriodDiscriminator,
    msd=MultiScaleDiscriminator,
    mrd=MultiResolutionDiscriminator,
    msstftd=MultiScaleSTFTDiscriminator,
    mssbcqtd=MultiScaleSubbandCQTDiscriminator,
)


class GANVocoderTrainer(VocoderTrainer):
    def __init__(self, args, cfg):
        """Set up all state needed to train a GAN vocoder.

        In order, this initializes the accelerator and the logger, creates the
        checkpoint directory, sets the random seed, builds the dataloaders,
        the generator/discriminators, the optimizers, the schedulers and the
        criterions, hands them to ``accelerate`` and, when requested, resumes
        from a checkpoint.

        Args:
            args: Run-time arguments. This method reads ``exp_name``,
                ``log_level``, ``resume_type`` and ``checkpoint``.
            cfg: Experiment configuration. ``cfg.exp_name`` is overwritten
                with ``args.exp_name``.
        """
        super().__init__()

        self.args = args
        self.cfg = cfg

        cfg.exp_name = args.exp_name

        # Init accelerator
        self._init_accelerator()
        self.accelerator.wait_for_everyone()

        # Init logger
        with self.accelerator.main_process_first():
            self.logger = get_logger(args.exp_name, log_level=args.log_level)

        self.logger.info("=" * 56)
        self.logger.info("||\t\t" + "New training process started." + "\t\t||")
        self.logger.info("=" * 56)
        self.logger.info("\n")
        self.logger.debug(f"Using {args.log_level.upper()} logging level.")
        self.logger.info(f"Experiment name: {args.exp_name}")
        # NOTE(review): ``self.exp_dir`` is not assigned in this method;
        # presumably it is set by ``_init_accelerator`` or the base class.
        self.logger.info(f"Experiment directory: {self.exp_dir}")
        self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
        if self.accelerator.is_main_process:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")

        # Init training status
        self.batch_count: int = 0
        self.step: int = 0
        self.epoch: int = 0

        # A non-positive ``max_epoch`` in the config means "train forever".
        self.max_epoch = (
            self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
        )
        self.logger.info(
            "Max epoch: {}".format(
                self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
            )
        )

        # Check potential errors
        # NOTE: the attributes below are only assigned on the main process.
        if self.accelerator.is_main_process:
            self._check_basic_configs()
            self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
            self.checkpoints_path = [
                [] for _ in range(len(self.save_checkpoint_stride))
            ]
            self.run_eval = self.cfg.train.run_eval

        # Set random seed
        with self.accelerator.main_process_first():
            start = time.monotonic_ns()
            self._set_random_seed(self.cfg.train.random_seed)
            end = time.monotonic_ns()
            self.logger.debug(
                f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
            )
            self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")

        # Build dataloader
        with self.accelerator.main_process_first():
            self.logger.info("Building dataset...")
            start = time.monotonic_ns()
            self.train_dataloader, self.valid_dataloader = self._build_dataloader()
            end = time.monotonic_ns()
            self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")

        # Build model
        with self.accelerator.main_process_first():
            self.logger.info("Building model...")
            start = time.monotonic_ns()
            self.generator, self.discriminators = self._build_model()
            end = time.monotonic_ns()
            self.logger.debug(self.generator)
            for _, discriminator in self.discriminators.items():
                self.logger.debug(discriminator)
            self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
            self.logger.info(f"Model parameters: {self._count_parameters()/1e6:.2f}M")

        # Build optimizers and schedulers
        with self.accelerator.main_process_first():
            self.logger.info("Building optimizer and scheduler...")
            start = time.monotonic_ns()
            (
                self.generator_optimizer,
                self.discriminator_optimizer,
            ) = self._build_optimizer()
            (
                self.generator_scheduler,
                self.discriminator_scheduler,
            ) = self._build_scheduler()
            end = time.monotonic_ns()
            self.logger.info(
                f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
            )

        # Accelerator preparing
        self.logger.info("Initializing accelerate...")
        start = time.monotonic_ns()
        (
            self.train_dataloader,
            self.valid_dataloader,
            self.generator,
            self.generator_optimizer,
            self.discriminator_optimizer,
            self.generator_scheduler,
            self.discriminator_scheduler,
        ) = self.accelerator.prepare(
            self.train_dataloader,
            self.valid_dataloader,
            self.generator,
            self.generator_optimizer,
            self.discriminator_optimizer,
            self.generator_scheduler,
            self.discriminator_scheduler,
        )
        # The discriminators live in a dict, so each one is prepared on its own.
        for key, discriminator in self.discriminators.items():
            self.discriminators[key] = self.accelerator.prepare_model(discriminator)
        end = time.monotonic_ns()
        self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")

        # Build criterions
        with self.accelerator.main_process_first():
            self.logger.info("Building criterion...")
            start = time.monotonic_ns()
            self.criterions = self._build_criterion()
            end = time.monotonic_ns()
            self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")

        # Resume checkpoints
        with self.accelerator.main_process_first():
            if args.resume_type:
                self.logger.info("Resuming from checkpoint...")
                start = time.monotonic_ns()
                ckpt_path = Path(args.checkpoint)
                # The last path component decides whether ``args.checkpoint``
                # is passed as the second or the first argument of
                # ``_load_model``.
                if self._is_valid_pattern(ckpt_path.parts[-1]):
                    ckpt_path = self._load_model(
                        None, args.checkpoint, args.resume_type
                    )
                else:
                    ckpt_path = self._load_model(
                        args.checkpoint, resume_type=args.resume_type
                    )
                end = time.monotonic_ns()
                self.logger.info(
                    f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
                )
                # Use a context manager so the file handle is closed
                # deterministically instead of being left to the GC.
                with open(os.path.join(ckpt_path, "ckpts.json"), "r") as ckpts_file:
                    self.checkpoints_path = json.load(ckpts_file)

            self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
            if self.accelerator.is_main_process:
                os.makedirs(self.checkpoint_dir, exist_ok=True)
            self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")

        # Path the config is dumped to at the start of ``train_loop``
        self.config_save_path = os.path.join(self.exp_dir, "args.json")

    def _build_dataset(self):
        """Return the dataset class and the collator class (not instances)
        used for GAN vocoder training, in that order.
        """
        dataset_cls, collator_cls = GANVocoderDataset, GANVocoderCollator
        return dataset_cls, collator_cls

    def _build_criterion(self):
        class feature_criterion(torch.nn.Module):
            def __init__(self, cfg):
                super(feature_criterion, self).__init__()
                self.cfg = cfg
                self.l1Loss = torch.nn.L1Loss(reduction="mean")
                self.l2Loss = torch.nn.MSELoss(reduction="mean")
                self.relu = torch.nn.ReLU()

            def __call__(self, fmap_r, fmap_g):
                loss = 0

                if self.cfg.model.generator in [
                    "hifigan",
                    "nsfhifigan",
                    "bigvgan",
                    "apnet",
                ]:
                    for dr, dg in zip(fmap_r, fmap_g):
                        for rl, gl in zip(dr, dg):
                            loss += torch.mean(torch.abs(rl - gl))

                    loss = loss * 2
                elif self.cfg.model.generator in ["melgan"]:
                    for dr, dg in zip(fmap_r, fmap_g):
                        for rl, gl in zip(dr, dg):
                            loss += self.l1Loss(rl, gl)

                    loss = loss * 10
                elif self.cfg.model.generator in ["codec"]:
                    for dr, dg in zip(fmap_r, fmap_g):
                        for rl, gl in zip(dr, dg):
                            loss = loss + self.l1Loss(rl, gl) / torch.mean(
                                torch.abs(rl)
                            )

                    KL_scale = len(fmap_r) * len(fmap_r[0])

                    loss = 3 * loss / KL_scale
                else:
                    raise NotImplementedError

                return loss

        class discriminator_criterion(torch.nn.Module):
            def __init__(self, cfg):
                super(discriminator_criterion, self).__init__()
                self.cfg = cfg
                self.l1Loss = torch.nn.L1Loss(reduction="mean")
                self.l2Loss = torch.nn.MSELoss(reduction="mean")
                self.relu = torch.nn.ReLU()

            def __call__(self, disc_real_outputs, disc_generated_outputs):
                loss = 0
                r_losses = []
                g_losses = []

                if self.cfg.model.generator in [
                    "hifigan",
                    "nsfhifigan",
                    "bigvgan",
                    "apnet",
                ]:
                    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
                        r_loss = torch.mean((1 - dr) ** 2)
                        g_loss = torch.mean(dg**2)
                        loss += r_loss + g_loss
                        r_losses.append(r_loss.item())
                        g_losses.append(g_loss.item())
                elif self.cfg.model.generator in ["melgan"]:
                    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
                        r_loss = torch.mean(self.relu(1 - dr))
                        g_loss = torch.mean(self.relu(1 + dg))
                        loss = loss + r_loss + g_loss
                        r_losses.append(r_loss.item())
                        g_losses.append(g_loss.item())
                elif self.cfg.model.generator in ["codec"]:
                    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
                        r_loss = torch.mean(self.relu(1 - dr))
                        g_loss = torch.mean(self.relu(1 + dg))
                        loss = loss + r_loss + g_loss
                        r_losses.append(r_loss.item())
                        g_losses.append(g_loss.item())

                    loss = loss / len(disc_real_outputs)
                else:
                    raise NotImplementedError

                return loss, r_losses, g_losses

        class generator_criterion(torch.nn.Module):
            def __init__(self, cfg):
                super(generator_criterion, self).__init__()
                self.cfg = cfg
                self.l1Loss = torch.nn.L1Loss(reduction="mean")
                self.l2Loss = torch.nn.MSELoss(reduction="mean")
                self.relu = torch.nn.ReLU()

            def __call__(self, disc_outputs):
                loss = 0
                gen_losses = []

                if self.cfg.model.generator in [
                    "hifigan",
                    "nsfhifigan",
                    "bigvgan",
                    "apnet",
                ]:
                    for dg in disc_outputs:
                        l = torch.mean((1 - dg) ** 2)
                        gen_losses.append(l)
                        loss += l
                elif self.cfg.model.generator in ["melgan"]:
                    for dg in disc_outputs:
                        l = -torch.mean(dg)
                        gen_losses.append(l)
                        loss += l
                elif self.cfg.model.generator in ["codec"]:
                    for dg in disc_outputs:
                        l = torch.mean(self.relu(1 - dg)) / len(disc_outputs)
                        gen_losses.append(l)
                        loss += l
                else:
                    raise NotImplementedError

                return loss, gen_losses

        class mel_criterion(torch.nn.Module):
            def __init__(self, cfg):
                super(mel_criterion, self).__init__()
                self.cfg = cfg
                self.l1Loss = torch.nn.L1Loss(reduction="mean")
                self.l2Loss = torch.nn.MSELoss(reduction="mean")
                self.relu = torch.nn.ReLU()

            def __call__(self, y_gt, y_pred):
                loss = 0

                if self.cfg.model.generator in [
                    "hifigan",
                    "nsfhifigan",
                    "bigvgan",
                    "melgan",
                    "codec",
                    "apnet",
                ]:
                    y_gt_mel = extract_mel_features(y_gt, self.cfg.preprocess)
                    y_pred_mel = extract_mel_features(
                        y_pred.squeeze(1), self.cfg.preprocess
                    )

                    loss = self.l1Loss(y_gt_mel, y_pred_mel) * 45
                else:
                    raise NotImplementedError

                return loss

        class wav_criterion(torch.nn.Module):
            def __init__(self, cfg):
                super(wav_criterion, self).__init__()
                self.cfg = cfg
                self.l1Loss = torch.nn.L1Loss(reduction="mean")
                self.l2Loss = torch.nn.MSELoss(reduction="mean")
                self.relu = torch.nn.ReLU()

            def __call__(self, y_gt, y_pred):
                loss = 0

                if self.cfg.model.generator in [
                    "hifigan",
                    "nsfhifigan",
                    "bigvgan",
                    "apnet",
                ]:
                    loss = self.l2Loss(y_gt, y_pred.squeeze(1)) * 100
                elif self.cfg.model.generator in ["melgan"]:
                    loss = self.l1Loss(y_gt, y_pred.squeeze(1)) / 10
                elif self.cfg.model.generator in ["codec"]:
                    loss = self.l1Loss(y_gt, y_pred.squeeze(1)) + self.l2Loss(
                        y_gt, y_pred.squeeze(1)
                    )
                    loss /= 10
                else:
                    raise NotImplementedError

                return loss

        class phase_criterion(torch.nn.Module):
            def __init__(self, cfg):
                super(phase_criterion, self).__init__()
                self.cfg = cfg
                self.l1Loss = torch.nn.L1Loss(reduction="mean")
                self.l2Loss = torch.nn.MSELoss(reduction="mean")
                self.relu = torch.nn.ReLU()

            def __call__(self, phase_gt, phase_pred):
                n_fft = self.cfg.preprocess.n_fft
                frames = phase_gt.size()[-1]

                GD_matrix = (
                    torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=1)
                    - torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=2)
                    - torch.eye(n_fft // 2 + 1)
                )
                GD_matrix = GD_matrix.to(phase_pred.device)

                GD_r = torch.matmul(phase_gt.permute(0, 2, 1), GD_matrix)
                GD_g = torch.matmul(phase_pred.permute(0, 2, 1), GD_matrix)

                PTD_matrix = (
                    torch.triu(torch.ones(frames, frames), diagonal=1)
                    - torch.triu(torch.ones(frames, frames), diagonal=2)
                    - torch.eye(frames)
                )
                PTD_matrix = PTD_matrix.to(phase_pred.device)

                PTD_r = torch.matmul(phase_gt, PTD_matrix)
                PTD_g = torch.matmul(phase_pred, PTD_matrix)

                IP_loss = torch.mean(-torch.cos(phase_gt - phase_pred))
                GD_loss = torch.mean(-torch.cos(GD_r - GD_g))
                PTD_loss = torch.mean(-torch.cos(PTD_r - PTD_g))

                return 100 * (IP_loss + GD_loss + PTD_loss)

        class amplitude_criterion(torch.nn.Module):
            def __init__(self, cfg):
                super(amplitude_criterion, self).__init__()
                self.cfg = cfg
                self.l1Loss = torch.nn.L1Loss(reduction="mean")
                self.l2Loss = torch.nn.MSELoss(reduction="mean")
                self.relu = torch.nn.ReLU()

            def __call__(self, log_amplitude_gt, log_amplitude_pred):
                amplitude_loss = self.l2Loss(log_amplitude_gt, log_amplitude_pred)

                return 45 * amplitude_loss

        class consistency_criterion(torch.nn.Module):
            def __init__(self, cfg):
                super(consistency_criterion, self).__init__()
                self.cfg = cfg
                self.l1Loss = torch.nn.L1Loss(reduction="mean")
                self.l2Loss = torch.nn.MSELoss(reduction="mean")
                self.relu = torch.nn.ReLU()

            def __call__(
                self,
                rea_gt,
                rea_pred,
                rea_pred_final,
                imag_gt,
                imag_pred,
                imag_pred_final,
            ):
                C_loss = torch.mean(
                    torch.mean(
                        (rea_pred - rea_pred_final) ** 2
                        + (imag_pred - imag_pred_final) ** 2,
                        (1, 2),
                    )
                )

                L_R = self.l1Loss(rea_gt, rea_pred)
                L_I = self.l1Loss(imag_gt, imag_pred)

                return 20 * (C_loss + 2.25 * (L_R + L_I))

        criterions = dict()
        for key in self.cfg.train.criterions:
            if key == "feature":
                criterions["feature"] = feature_criterion(self.cfg)
            elif key == "discriminator":
                criterions["discriminator"] = discriminator_criterion(self.cfg)
            elif key == "generator":
                criterions["generator"] = generator_criterion(self.cfg)
            elif key == "mel":
                criterions["mel"] = mel_criterion(self.cfg)
            elif key == "wav":
                criterions["wav"] = wav_criterion(self.cfg)
            elif key == "phase":
                criterions["phase"] = phase_criterion(self.cfg)
            elif key == "amplitude":
                criterions["amplitude"] = amplitude_criterion(self.cfg)
            elif key == "consistency":
                criterions["consistency"] = consistency_criterion(self.cfg)
            else:
                raise NotImplementedError

        return criterions

    def _build_model(self):
        """Instantiate the generator and the discriminators named in the config.

        The generator class is looked up in ``supported_generators`` by
        ``cfg.model.generator``; each name in ``cfg.model.discriminators`` is
        looked up in ``supported_discriminators``. Every class is constructed
        with the full config.

        Returns:
            tuple: ``(generator, discriminators)`` where ``discriminators`` is
            a dict keyed by discriminator name.
        """
        generator_cls = supported_generators[self.cfg.model.generator]
        generator = generator_cls(self.cfg)

        discriminators = {
            name: supported_discriminators[name](self.cfg)
            for name in self.cfg.model.discriminators
        }

        return generator, discriminators

    def _build_optimizer(self):
        optimizer_params_generator = [dict(params=self.generator.parameters())]
        generator_optimizer = AdamW(
            optimizer_params_generator,
            lr=self.cfg.train.adamw.lr,
            betas=(self.cfg.train.adamw.adam_b1, self.cfg.train.adamw.adam_b2),
        )

        optimizer_params_discriminator = []
        for discriminator in self.discriminators.keys():
            optimizer_params_discriminator.append(
                dict(params=self.discriminators[discriminator].parameters())
            )
        discriminator_optimizer = AdamW(
            optimizer_params_discriminator,
            lr=self.cfg.train.adamw.lr,
            betas=(self.cfg.train.adamw.adam_b1, self.cfg.train.adamw.adam_b2),
        )

        return generator_optimizer, discriminator_optimizer

    def _build_scheduler(self):
        discriminator_scheduler = ExponentialLR(
            self.discriminator_optimizer,
            gamma=self.cfg.train.exponential_lr.lr_decay,
            last_epoch=self.epoch - 1,
        )

        generator_scheduler = ExponentialLR(
            self.generator_optimizer,
            gamma=self.cfg.train.exponential_lr.lr_decay,
            last_epoch=self.epoch - 1,
        )

        return generator_scheduler, discriminator_scheduler

    def train_loop(self):
        """Run the training process until ``self.max_epoch`` is reached.

        Each epoch trains and validates, logs the losses, steps both
        schedulers and, on the main process, saves a checkpoint and/or
        evaluation audio whenever the epoch is a multiple of one of the
        configured checkpoint strides. A final state is saved after the loop.
        """
        self.accelerator.wait_for_everyone()

        # Dump config
        if self.accelerator.is_main_process:
            self._dump_cfg(self.config_save_path)
        self.generator.train()
        for discriminator in self.discriminators.values():
            discriminator.train()
        self.generator_optimizer.zero_grad()
        self.discriminator_optimizer.zero_grad()

        # Bound up front so the final checkpoint below can still be named when
        # the loop body never runs (``self.epoch`` already >= ``self.max_epoch``).
        valid_total_loss = 0.0

        # Sync and start training
        self.accelerator.wait_for_everyone()
        while self.epoch < self.max_epoch:
            self.logger.info("\n")
            self.logger.info("-" * 32)
            self.logger.info("Epoch {}: ".format(self.epoch))

            # Train and Validate
            train_total_loss, train_losses = self._train_epoch()
            for key, loss in train_losses.items():
                self.logger.info("  |- Train/{} Loss: {:.6f}".format(key, loss))
                self.accelerator.log(
                    {"Epoch/Train {} Loss".format(key): loss},
                    step=self.epoch,
                )
            valid_total_loss, valid_losses = self._valid_epoch()
            for key, loss in valid_losses.items():
                self.logger.info("  |- Valid/{} Loss: {:.6f}".format(key, loss))
                self.accelerator.log(
                    {"Epoch/Valid {} Loss".format(key): loss},
                    step=self.epoch,
                )
            self.accelerator.log(
                {
                    "Epoch/Train Total Loss": train_total_loss,
                    "Epoch/Valid Total Loss": valid_total_loss,
                },
                step=self.epoch,
            )

            # Update scheduler
            self.accelerator.wait_for_everyone()
            self.generator_scheduler.step()
            self.discriminator_scheduler.step()

            # Check save checkpoint interval. The stride settings only exist on
            # the main process, so both flags stay False everywhere else.
            run_eval = False
            save_checkpoint = False
            if self.accelerator.is_main_process:
                for i, num in enumerate(self.save_checkpoint_stride):
                    if self.epoch % num == 0:
                        save_checkpoint = True
                        run_eval |= self.run_eval[i]

            # Common name prefix of this epoch's checkpoint and eval audios
            ckpt_name = "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                self.epoch, self.step, valid_total_loss
            )

            # Save checkpoints
            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process and save_checkpoint:
                path = os.path.join(self.checkpoint_dir, ckpt_name)
                self.accelerator.save_state(path)
                # Context manager guarantees the JSON is flushed and the
                # handle closed before training continues.
                with open(os.path.join(path, "ckpts.json"), "w") as ckpts_file:
                    json.dump(
                        self.checkpoints_path,
                        ckpts_file,
                        ensure_ascii=False,
                        indent=4,
                    )

            # Save eval audios
            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process and run_eval:
                eval_set = self.valid_dataloader.dataset
                sample_rate = self.cfg.preprocess.sample_rate
                for i, eval_audio_gt in enumerate(eval_set.eval_audios):
                    if self.cfg.preprocess.use_frame_pitch:
                        eval_audio = self._inference(
                            eval_set.eval_mels[i],
                            eval_pitch=eval_set.eval_pitchs[i],
                            use_pitch=True,
                        )
                    else:
                        eval_audio = self._inference(eval_set.eval_mels[i])

                    # Prediction and ground truth are written side by side:
                    # ``<stem>.wav`` and ``<stem>_gt.wav``.
                    audio_stem = os.path.join(
                        self.checkpoint_dir,
                        "{}_eval_audio_{}".format(
                            ckpt_name, eval_set.eval_dataset_names[i]
                        ),
                    )
                    save_audio(audio_stem + ".wav", eval_audio, sample_rate)
                    save_audio(audio_stem + "_gt.wav", eval_audio_gt, sample_rate)

            self.accelerator.wait_for_everyone()

            self.epoch += 1

        # Finish training
        self.accelerator.wait_for_everyone()
        path = os.path.join(
            self.checkpoint_dir,
            "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                self.epoch, self.step, valid_total_loss
            ),
        )
        self.accelerator.save_state(path)

    def _train_epoch(self):
        """Train for one pass over ``self.train_dataloader``.

        Losses are logged and accumulated only on batches where
        ``self.batch_count`` is a multiple of
        ``cfg.train.gradient_accumulation_step``; ``self.step`` is incremented
        on those batches. See ``train_loop`` for usage.

        Returns:
            tuple: ``(epoch_total_loss, epoch_losses)``, the accumulated total
            loss and per-key losses, each divided by the number of batches and
            multiplied by the gradient accumulation step.
        """
        self.generator.train()
        for discriminator in self.discriminators.values():
            discriminator.train()

        accumulation_step = self.cfg.train.gradient_accumulation_step
        epoch_losses: dict = {}
        epoch_total_loss: int = 0

        progress_bar = tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        )
        for batch in progress_bar:
            # Get losses
            total_loss, losses = self._train_step(batch)
            self.batch_count += 1

            # Everything below happens once per accumulation window only.
            if self.batch_count % accumulation_step != 0:
                continue

            # Log info
            generator_lr = self.generator_optimizer.param_groups[0]["lr"]
            discriminator_lr = self.discriminator_optimizer.param_groups[0]["lr"]
            self.accelerator.log(
                {
                    "Step/Generator Learning Rate": generator_lr,
                    "Step/Discriminator Learning Rate": discriminator_lr,
                },
                step=self.step,
            )
            for key, value in losses.items():
                self.accelerator.log(
                    {"Step/Train {} Loss".format(key): value},
                    step=self.step,
                )

            if epoch_losses:
                for key, value in losses.items():
                    epoch_losses[key] += value
            else:
                # First accumulated batch: adopt its dict as the accumulator.
                epoch_losses = losses
            epoch_total_loss += total_loss
            self.step += 1

        # Get and log total losses
        self.accelerator.wait_for_everyone()
        num_batches = len(self.train_dataloader)

        def rescale(value):
            return value / num_batches * self.cfg.train.gradient_accumulation_step

        epoch_total_loss = rescale(epoch_total_loss)
        for key in epoch_losses:
            epoch_losses[key] = rescale(epoch_losses[key])
        return epoch_total_loss, epoch_losses

    def _train_step(self, data):
        """Training forward step. Should return average loss of a sample over
        one batch. Provoke ``_forward_step`` is recommended except for special case.
        See ``_train_epoch`` for usage.
        """
        # Init losses
        train_losses = {}
        total_loss = 0

        generator_losses = {}
        generator_total_loss = 0
        discriminator_losses = {}
        discriminator_total_loss = 0

        # Use input feature to get predictions
        mel_input = data["mel"]
        audio_gt = data["audio"]

        if self.cfg.preprocess.extract_amplitude_phase:
            logamp_gt = data["logamp"]
            pha_gt = data["pha"]
            rea_gt = data["rea"]
            imag_gt = data["imag"]

        if self.cfg.preprocess.use_frame_pitch:
            pitch_input = data["frame_pitch"]

        if self.cfg.preprocess.use_frame_pitch:
            pitch_input = pitch_input.float()
            audio_pred = self.generator.forward(mel_input, pitch_input)
        elif self.cfg.preprocess.extract_amplitude_phase:
            (
                logamp_pred,
                pha_pred,
                rea_pred,
                imag_pred,
                audio_pred,
            ) = self.generator.forward(mel_input)
            from utils.mel import amplitude_phase_spectrum

            _, _, rea_pred_final, imag_pred_final = amplitude_phase_spectrum(
                audio_pred.squeeze(1), self.cfg.preprocess
            )
        else:
            audio_pred = self.generator.forward(mel_input)

        # Calculate and BP Discriminator losses
        self.discriminator_optimizer.zero_grad()
        for key, _ in self.discriminators.items():
            y_r, y_g, _, _ = self.discriminators[key].forward(
                audio_gt.unsqueeze(1), audio_pred.detach()
            )
            (
                discriminator_losses["{}_discriminator".format(key)],
                _,
                _,
            ) = self.criterions["discriminator"](y_r, y_g)
            discriminator_total_loss += discriminator_losses[
                "{}_discriminator".format(key)
            ]

        self.accelerator.backward(discriminator_total_loss)
        self.discriminator_optimizer.step()

        # Calculate and BP Generator losses
        self.generator_optimizer.zero_grad()
        for key, _ in self.discriminators.items():
            y_r, y_g, f_r, f_g = self.discriminators[key].forward(
                audio_gt.unsqueeze(1), audio_pred
            )
            generator_losses["{}_feature".format(key)] = self.criterions["feature"](
                f_r, f_g
            )
            generator_losses["{}_generator".format(key)], _ = self.criterions[
                "generator"
            ](y_g)
            generator_total_loss += generator_losses["{}_feature".format(key)]
            generator_total_loss += generator_losses["{}_generator".format(key)]

        if "mel" in self.criterions.keys():
            generator_losses["mel"] = self.criterions["mel"](audio_gt, audio_pred)
            generator_total_loss += generator_losses["mel"]

        if "wav" in self.criterions.keys():
            generator_losses["wav"] = self.criterions["wav"](audio_gt, audio_pred)
            generator_total_loss += generator_losses["wav"]

        if "amplitude" in self.criterions.keys():
            generator_losses["amplitude"] = self.criterions["amplitude"](
                logamp_gt, logamp_pred
            )
            generator_total_loss += generator_losses["amplitude"]

        if "phase" in self.criterions.keys():
            generator_losses["phase"] = self.criterions["phase"](pha_gt, pha_pred)
            generator_total_loss += generator_losses["phase"]

        if "consistency" in self.criterions.keys():
            generator_losses["consistency"] = self.criterions["consistency"](
                rea_gt, rea_pred, rea_pred_final, imag_gt, imag_pred, imag_pred_final
            )
            generator_total_loss += generator_losses["consistency"]

        self.accelerator.backward(generator_total_loss)
        self.generator_optimizer.step()

        # Get the total losses
        total_loss = discriminator_total_loss + generator_total_loss
        train_losses.update(discriminator_losses)
        train_losses.update(generator_losses)

        for key, _ in train_losses.items():
            train_losses[key] = train_losses[key].item()

        return total_loss.item(), train_losses

    def _valid_epoch(self):
        """Run one pass over the validation dataloader.

        The generator and all discriminators are switched to eval mode, every
        batch is passed through ``_valid_step``, and each per-batch loss is
        logged through the accelerator at the current ``self.step``.

        Returns:
            A ``(epoch_total_loss, epoch_losses)`` tuple: the summed total loss
            and the per-name summed losses, both divided by
            ``len(self.valid_dataloader)``.
        """
        self.generator.eval()
        for discriminator in self.discriminators.values():
            discriminator.eval()

        loss_sums: dict = {}
        total_sum: int = 0

        progress_bar = tqdm(
            self.valid_dataloader,
            desc=f"Validating Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        )
        for batch in progress_bar:
            batch_total, batch_losses = self._valid_step(batch)

            # One log call per loss entry, all at the same global step.
            for name in batch_losses:
                self.accelerator.log(
                    {"Step/Valid {} Loss".format(name): batch_losses[name]},
                    step=self.step,
                )

            # The first non-empty batch dict seeds the running sums; later
            # batches are accumulated into it key by key.
            if loss_sums:
                for name, value in batch_losses.items():
                    loss_sums[name] += value
            else:
                loss_sums = batch_losses
            total_sum += batch_total

        # Synchronize processes before reducing the sums to averages.
        self.accelerator.wait_for_everyone()
        total_sum = total_sum / len(self.valid_dataloader)
        for name in loss_sums.keys():
            loss_sums[name] = loss_sums[name] / len(self.valid_dataloader)
        return total_sum, loss_sums

    def _valid_step(self, data):
        """Testing forward step. Should return average loss of a sample over
        one batch. Provoke ``_forward_step`` is recommended except for special case.
        See ``_test_epoch`` for usage.
        """
        # Init losses
        valid_losses = {}
        total_loss = 0

        generator_losses = {}
        generator_total_loss = 0
        discriminator_losses = {}
        discriminator_total_loss = 0

        # Use feature inputs to get the predicted audio
        mel_input = data["mel"]
        audio_gt = data["audio"]

        if self.cfg.preprocess.extract_amplitude_phase:
            logamp_gt = data["logamp"]
            pha_gt = data["pha"]
            rea_gt = data["rea"]
            imag_gt = data["imag"]

        if self.cfg.preprocess.use_frame_pitch:
            pitch_input = data["frame_pitch"]

        if self.cfg.preprocess.use_frame_pitch:
            pitch_input = pitch_input.float()
            audio_pred = self.generator.forward(mel_input, pitch_input)
        elif self.cfg.preprocess.extract_amplitude_phase:
            (
                logamp_pred,
                pha_pred,
                rea_pred,
                imag_pred,
                audio_pred,
            ) = self.generator.forward(mel_input)
            from utils.mel import amplitude_phase_spectrum

            _, _, rea_pred_final, imag_pred_final = amplitude_phase_spectrum(
                audio_pred.squeeze(1), self.cfg.preprocess
            )
        else:
            audio_pred = self.generator.forward(mel_input)

        # Get Discriminator losses
        for key, _ in self.discriminators.items():
            y_r, y_g, _, _ = self.discriminators[key].forward(
                audio_gt.unsqueeze(1), audio_pred
            )
            (
                discriminator_losses["{}_discriminator".format(key)],
                _,
                _,
            ) = self.criterions["discriminator"](y_r, y_g)
            discriminator_total_loss += discriminator_losses[
                "{}_discriminator".format(key)
            ]

        for key, _ in self.discriminators.items():
            y_r, y_g, f_r, f_g = self.discriminators[key].forward(
                audio_gt.unsqueeze(1), audio_pred
            )
            generator_losses["{}_feature".format(key)] = self.criterions["feature"](
                f_r, f_g
            )
            generator_losses["{}_generator".format(key)], _ = self.criterions[
                "generator"
            ](y_g)
            generator_total_loss += generator_losses["{}_feature".format(key)]
            generator_total_loss += generator_losses["{}_generator".format(key)]

        if "mel" in self.criterions.keys():
            generator_losses["mel"] = self.criterions["mel"](audio_gt, audio_pred)
            generator_total_loss += generator_losses["mel"]
        if "mel" in self.criterions.keys():
            generator_losses["mel"] = self.criterions["mel"](audio_gt, audio_pred)
            generator_total_loss += generator_losses["mel"]

        if "wav" in self.criterions.keys():
            generator_losses["wav"] = self.criterions["wav"](audio_gt, audio_pred)
            generator_total_loss += generator_losses["wav"]
        if "wav" in self.criterions.keys():
            generator_losses["wav"] = self.criterions["wav"](audio_gt, audio_pred)
            generator_total_loss += generator_losses["wav"]

        if "amplitude" in self.criterions.keys():
            generator_losses["amplitude"] = self.criterions["amplitude"](
                logamp_gt, logamp_pred
            )
            generator_total_loss += generator_losses["amplitude"]

        if "phase" in self.criterions.keys():
            generator_losses["phase"] = self.criterions["phase"](pha_gt, pha_pred)
            generator_total_loss += generator_losses["phase"]

        if "consistency" in self.criterions.keys():
            generator_losses["consistency"] = self.criterions["consistency"](
                rea_gt,
                rea_pred,
                rea_pred_final,
                imag_gt,
                imag_pred,
                imag_pred_final,
            )
            generator_total_loss += generator_losses["consistency"]

        total_loss = discriminator_total_loss + generator_total_loss
        valid_losses.update(discriminator_losses)
        valid_losses.update(generator_losses)

        for item in valid_losses:
            valid_losses[item] = valid_losses[item].item()
        for item in valid_losses:
            valid_losses[item] = valid_losses[item].item()

        return total_loss.item(), valid_losses
        return total_loss.item(), valid_losses

    def _inference(self, eval_mel, eval_pitch=None, use_pitch=False):
        """Synthesize one evaluation audio with the current generator.

        Args:
            eval_mel: Array accepted by ``torch.from_numpy``; a leading batch
                dimension is added before it is passed to
                ``vocoder_inference``.
            eval_pitch: Pitch array, only used when ``use_pitch`` is true. It
                is aligned to ``eval_mel.shape[1]`` via ``align_length``.
            use_pitch: Whether to pass the pitch as ``f0s`` to
                ``vocoder_inference``.

        Returns:
            The output of ``vocoder_inference`` with its leading dimension
            squeezed.
        """
        pitch_kwargs = {}
        if use_pitch:
            aligned_pitch = align_length(eval_pitch, eval_mel.shape[1])
            pitch_kwargs["f0s"] = torch.from_numpy(aligned_pitch).unsqueeze(0).float()

        mel_batch = torch.from_numpy(eval_mel).unsqueeze(0)
        # Run on whatever device the generator's parameters live on.
        generator_device = next(self.generator.parameters()).device

        synthesized = vocoder_inference(
            self.cfg,
            self.generator,
            mel_batch,
            device=generator_device,
            **pitch_kwargs,
        )
        return synthesized.squeeze(0)

    def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
        """Load model from checkpoint. If checkpoint_path is None, it will
        load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
        None, it will load the checkpoint specified by checkpoint_path. **Only use this
        method after** ``accelerator.prepare()``.
        """
        if checkpoint_path is None:
            ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
            ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
            checkpoint_path = ls[0]
        if resume_type == "resume":
            self.accelerator.load_state(checkpoint_path)
            self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
            self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
        elif resume_type == "finetune":
            accelerate.load_checkpoint_and_dispatch(
                self.accelerator.unwrap_model(self.generator),
                os.path.join(checkpoint_path, "pytorch_model.bin"),
            )
            for key, _ in self.discriminators.items():
                accelerate.load_checkpoint_and_dispatch(
                    self.accelerator.unwrap_model(self.discriminators[key]),
                    os.path.join(checkpoint_path, "pytorch_model.bin"),
                )
            self.logger.info("Load model weights for finetune SUCCESS!")
        else:
            raise ValueError("Unsupported resume type: {}".format(resume_type))
        return checkpoint_path

    def _count_parameters(self):
        result = sum(p.numel() for p in self.generator.parameters())
        for _, discriminator in self.discriminators.items():
            result += sum(p.numel() for p in discriminator.parameters())
        return result
